{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train with NeMo\n",
    "\n",
    "[Neural Modules (NeMo)](https://nvidia.github.io/NeMo/index.html) is a framework-agnostic toolkit for building AI applications. It currently supports the PyTorch framework.\n",
    "\n",
    "Using NeMo to train a PyTorch model is simple. In this notebook, we will demonstrate how to use NeMo to train the Asian Barrier Option pricing model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Defining the trainable module is similar to defining a PyTorch module but it defines the input and output ports:-"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting nemo_model.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile nemo_model.py\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "from nemo.core.neural_types import BatchTag, ChannelTag, NeuralType, AxisType\n",
    "import nemo\n",
    "\n",
    "class Net(nemo.backends.pytorch.nm.TrainableNM):\n",
    "#class Net(nn.Module):\n",
    "    @staticmethod\n",
    "    def create_ports():\n",
    "        input_ports = {\"x\": NeuralType({0: AxisType(BatchTag),\n",
    "                                        1: AxisType(ChannelTag, 6)})}\n",
    "        output_ports = {\"y_pred\": NeuralType({0: AxisType(BatchTag),\n",
    "                                              1: AxisType(ChannelTag, 1)})}\n",
    "        return input_ports, output_ports\n",
    "\n",
    "    def __init__(self, hidden=512, **kwargs):\n",
    "        super(Net, self).__init__(**kwargs)\n",
    "        self.fc1 = nn.Linear(6, hidden)\n",
    "        self.fc2 = nn.Linear(hidden, hidden)\n",
    "        self.fc3 = nn.Linear(hidden, hidden)\n",
    "        self.fc4 = nn.Linear(hidden, hidden)\n",
    "        self.fc5 = nn.Linear(hidden, hidden)\n",
    "        self.fc6 = nn.Linear(hidden, 1)\n",
    "        self.register_buffer('norm',\n",
    "                             torch.tensor([200.0,\n",
    "                                           198.0,\n",
    "                                           200.0,\n",
    "                                           0.4,\n",
    "                                           0.2,\n",
    "                                           0.2]))\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x / self.norm\n",
    "        x = F.elu(self.fc1(x))\n",
    "        x = F.elu(self.fc2(x))\n",
    "        x = F.elu(self.fc3(x))\n",
    "        x = F.elu(self.fc4(x))\n",
    "        x = F.elu(self.fc5(x))\n",
    "        return self.fc6(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The NeMo DataLayer module is wrapped around the normal PyTorch Dataset:-"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing nemo_datalayer.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile nemo_datalayer.py\n",
    "import torch\n",
    "import nemo\n",
    "from nemo.core.neural_types import BatchTag, ChannelTag, NeuralType, AxisType\n",
    "\n",
    "\n",
    "class OptionDataSet(torch.utils.data.Dataset):\n",
    "    def __init__(self, filename, rank=0, world_size=5):\n",
    "        tensor = torch.load(filename)\n",
    "        self.tensor = (tensor[0], tensor[1])\n",
    "        self.length = len(self.tensor[0]) // world_size\n",
    "        self.world_size = world_size\n",
    "        self.rank = rank\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        index = index * self.world_size + self.rank\n",
    "\n",
    "        return self.tensor[0][index], self.tensor[1][index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.length\n",
    "\n",
    "class OptionDataLayer(nemo.backends.pytorch.nm.DataLayerNM):\n",
    "    @staticmethod\n",
    "    def create_ports():\n",
    "        # Note: we define the size of the height and width of our output\n",
    "        # tensors, and thus require a size parameter.\n",
    "        input_ports = {}\n",
    "        output_ports = {\n",
    "            \"x\": NeuralType({0: AxisType(BatchTag),\n",
    "                                 1: AxisType(ChannelTag, 6)}),\n",
    "            \"ground\": NeuralType({0: AxisType(BatchTag)})\n",
    "        }\n",
    "        return input_ports, output_ports\n",
    "\n",
    "    def __init__(self, filename, rank=0, world_size=5, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self._dataset = OptionDataSet(filename, rank, world_size)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._dataset)\n",
    "\n",
    "    @property\n",
    "    def dataset(self):\n",
    "        return self._dataset\n",
    "\n",
    "    @property\n",
    "    def data_iterator(self):\n",
    "        return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define the Loss Neural Module as following, which wraps around the PyTorch MSELoss with added input and output types:-"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting nemo_losslayer.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile nemo_losslayer.py\n",
    "import nemo\n",
    "from nemo.core.neural_types import BatchTag, ChannelTag, NeuralType, AxisType\n",
    "import torch\n",
    "\n",
    "class MSELoss(nemo.backends.pytorch.nm.LossNM):\n",
    "    @staticmethod\n",
    "    def create_ports():\n",
    "        input_ports = {\"y_pred\": NeuralType({0: AxisType(BatchTag),\n",
    "                                             1: AxisType(ChannelTag, 1)}),\n",
    "                       \"ground\": NeuralType({0: AxisType(BatchTag)})}\n",
    "        output_ports = {\"loss\": NeuralType(None)}\n",
    "        return input_ports, output_ports\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        # Neural Module API specific\n",
    "        super().__init__(**kwargs)\n",
    "        # End of Neural Module API specific\n",
    "        self._loss = torch.nn.MSELoss()\n",
    "\n",
    "    # You need to implement this function\n",
    "    def _loss_function(self, **kwargs):\n",
    "        v = self._loss(kwargs['y_pred'][:,0], kwargs['ground'])\n",
    "        return v"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use Neural Modules, we need to following 3 steps:-\n",
    "\n",
    "1. Creation of NeuralModuleFactory and necessary NeuralModule\n",
    "2. Defining a Directed Acyclic Graph (DAG) of NeuralModule\n",
    "3. Call to “action” such as train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2019-11-18 21:52:15,176 - WARNING - Data Layer does not have any weights to return. This get_weights call returns None.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting .....\n",
      "Starting epoch 0\n",
      "Step: 0\n",
      "Train Loss: 2603.74560546875\n",
      "Step time: 0.34972667694091797 seconds\n",
      "Finished epoch 0 in 1.576570987701416\n",
      "Starting epoch 1\n",
      "Step: 25\n",
      "Train Loss: 2615.87451171875\n",
      "Step time: 0.0040209293365478516 seconds\n",
      "Finished epoch 1 in 1.1546299457550049\n",
      "Starting epoch 2\n",
      "Step: 50\n",
      "Train Loss: 374.9800109863281\n",
      "Step time: 0.004566669464111328 seconds\n",
      "Finished epoch 2 in 1.1191346645355225\n",
      "Starting epoch 3\n",
      "Finished epoch 3 in 1.0970311164855957\n",
      "Starting epoch 4\n",
      "Step: 75\n",
      "Train Loss: 39.58891677856445\n",
      "Step time: 0.0044879913330078125 seconds\n",
      "Finished epoch 4 in 1.1738121509552002\n",
      "Starting epoch 5\n",
      "Step: 100\n",
      "Train Loss: 9.877204895019531\n",
      "Step time: 0.004233360290527344 seconds\n",
      "Finished epoch 5 in 1.15018892288208\n",
      "Starting epoch 6\n",
      "Step: 125\n",
      "Train Loss: 1.8599201440811157\n",
      "Step time: 0.005522727966308594 seconds\n",
      "Finished epoch 6 in 1.2052836418151855\n",
      "Starting epoch 7\n",
      "Finished epoch 7 in 1.197835922241211\n",
      "Starting epoch 8\n"
     ]
    }
   ],
   "source": [
    "import nemo\n",
    "from nemo.core import DeviceType\n",
    "from nemo_model import Net\n",
    "from nemo_datalayer import OptionDataLayer\n",
    "from nemo_losslayer import MSELoss\n",
    "nf = nemo.core.NeuralModuleFactory()\n",
    "# nf = nemo.core.NeuralModuleFactory()\n",
    "dl= OptionDataLayer('trn.pth', 0, 1, batch_size=32)\n",
    "\n",
    "# instantiate necessary neural modules\n",
    "fx = Net(hidden=512).cuda() #, placement=DeviceType.GPU)\n",
    "loss = MSELoss()\n",
    "\n",
    "# describe activation's flow\n",
    "x, y = dl()\n",
    "p = fx(x=x)\n",
    "lss = loss(y_pred=p, ground=y)\n",
    "\n",
    "# SimpleLossLoggerCallback will print loss values to console.\n",
    "callback = nemo.core.SimpleLossLoggerCallback(\n",
    "    tensors=[lss],\n",
    "    print_func=lambda x: print(f'Train Loss: {str(x[0].item())}'))\n",
    "\n",
    "# Invoke \"train\" action\n",
    "nf.train([lss], callbacks=[callback],\n",
    "         optimization_params={\"num_epochs\": 20, \"lr\": 0.0003},\n",
    "         optimizer=\"adam\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NVIDIA Volta and Turing GPUs have Tensor Cores which can do fast matrix multiplications with values in float16 format. To enable mixed-precision in NeMo all you need to do is to set the optimization_level parameter of nemo.core.NeuralModuleFactory to nemo.core.Optimization.mxprO1. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nf = nemo.core.NeuralModuleFactory(optimization_level=nemo.core.Optimization.mxprO1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For multi-GPU training, follow two steps in NeMo:\n",
    "1. Set placement to nemo.core.DeviceType.AllGpu in NeuralModuleFactory\n",
    "2. Add the ‘local_rank’ argument to your script and do not set it yourself: parser.add_argument(“–local_rank”, default=None, type=int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting nemo_dis_train.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile nemo_dis_train.py\n",
    "import nemo\n",
    "from nemo.core import DeviceType\n",
    "from nemo_model import Net\n",
    "from nemo_datalayer import OptionDataLayer\n",
    "from nemo_losslayer import MSELoss\n",
    "import argparse\n",
    "import os\n",
    "parser = argparse.ArgumentParser(description='ResNet50 on ImageNet')\n",
    "parser.add_argument(\"--local_rank\", default=None, type=int)\n",
    "\n",
    "args = parser.parse_args()\n",
    "\n",
    "if args.local_rank is not None:\n",
    "    device = nemo.core.DeviceType.AllGpu\n",
    "else:\n",
    "    device = nemo.core.DeviceType.GPU\n",
    "    \n",
    "world_size = int(os.environ['WORLD_SIZE'])\n",
    "\n",
    "nf = nemo.core.NeuralModuleFactory(backend=nemo.core.Backend.PyTorch,\n",
    "    local_rank=args.local_rank,\n",
    "    placement=device,                               \n",
    "    optimization_level=nemo.core.Optimization.mxprO1)\n",
    "# nf = nemo.core.NeuralModuleFactory()\n",
    "dl= OptionDataLayer('trn.pth', args.local_rank, world_size, batch_size=32)\n",
    "\n",
    "# instantiate necessary neural modules\n",
    "# RealFunctionDataLayer defaults to f=torch.sin, sampling from x=[-4, 4]\n",
    "fx = Net(hidden=512).cuda() #, placement=DeviceType.GPU)\n",
    "loss = MSELoss()\n",
    "\n",
    "# describe activation's flow\n",
    "x, y = dl()\n",
    "p = fx(x=x)\n",
    "lss = loss(y_pred=p, ground=y)\n",
    "\n",
    "# SimpleLossLoggerCallback will print loss values to console.\n",
    "callback = nemo.core.SimpleLossLoggerCallback(\n",
    "    tensors=[lss],\n",
    "    print_func=lambda x: print(f'Train Loss: {str(x[0].item())}'))\n",
    "\n",
    "# Invoke \"train\" action\n",
    "nf.train([lss], callbacks=[callback],\n",
    "         optimization_params={\"num_epochs\": 20, \"lr\": 0.0003},\n",
    "         optimizer=\"adam\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "*****************************************\n",
      "Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
      "*****************************************\n",
      "WARNING:root:Data Layer does not have any weights to return. This get_weights call returns None.\n",
      "Doing distributed training\n",
      "WARNING:root:Data Layer does not have any weights to return. This get_weights call returns None.\n",
      "Doing distributed training\n",
      "WARNING:root:Data Layer does not have any weights to return. This get_weights call returns None.\n",
      "Doing distributed training\n",
      "2019-11-18 21:59:47,135 - WARNING - Data Layer does not have any weights to return. This get_weights call returns None.\n",
      "Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.\n",
      "\n",
      "Defaults for this optimization level are:\n",
      "enabled                : True\n",
      "opt_level              : O1\n",
      "cast_model_type        : None\n",
      "patch_torch_functions  : True\n",
      "keep_batchnorm_fp32    : None\n",
      "master_weights         : None\n",
      "loss_scale             : dynamic\n",
      "Processing user overrides (additional kwargs that are not None)...\n",
      "After processing overrides, optimization options are:\n",
      "enabled                : True\n",
      "opt_level              : O1\n",
      "cast_model_type        : None\n",
      "patch_torch_functions  : True\n",
      "keep_batchnorm_fp32    : None\n",
      "master_weights         : None\n",
      "loss_scale             : dynamic\n",
      "Doing distributed training\n",
      "Starting .....\n",
      "Starting epoch 0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0\n",
      "Step: 0\n",
      "Train Loss: 3221.57763671875\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0\n",
      "Step time: 0.6914923191070557 seconds\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0\n",
      "Finished epoch 0 in 1.6740753650665283\n",
      "Starting epoch 1\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0\n",
      "Finished epoch 1 in 1.2843654155731201\n",
      "Starting epoch 2\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0\n",
      "Finished epoch 2 in 1.2458338737487793\n",
      "Starting epoch 3\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 256.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 256.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 256.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 256.0\n",
      "Finished epoch 3 in 1.2108018398284912\n",
      "Starting epoch 4\n",
      "Finished epoch 4 in 1.2647509574890137\n",
      "Starting epoch 5\n",
      "Finished epoch 5 in 1.2803404331207275\n",
      "Starting epoch 6\n",
      "Finished epoch 6 in 1.2259752750396729\n",
      "Starting epoch 7\n",
      "Finished epoch 7 in 1.1974444389343262\n",
      "Starting epoch 8\n",
      "Finished epoch 8 in 1.198972225189209\n",
      "Starting epoch 9\n",
      "Finished epoch 9 in 1.2113327980041504\n",
      "Starting epoch 10\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 128.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 128.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 128.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 128.0\n",
      "Finished epoch 10 in 1.2995426654815674\n",
      "Starting epoch 11\n",
      "Finished epoch 11 in 1.2914905548095703\n",
      "Starting epoch 12\n",
      "Step: 25\n",
      "Train Loss: 995.6943969726562\n",
      "Step time: 0.00768280029296875 seconds\n",
      "Finished epoch 12 in 1.258274793624878\n",
      "Starting epoch 13\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 64.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 64.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 64.0\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 64.0\n",
      "Finished epoch 13 in 1.2511169910430908\n",
      "Starting epoch 14\n",
      "Finished epoch 14 in 1.2451517581939697\n",
      "Starting epoch 15\n",
      "Finished epoch 15 in 1.2092945575714111\n",
      "Starting epoch 16\n",
      "Finished epoch 16 in 1.2507727146148682\n",
      "Starting epoch 17\n",
      "Finished epoch 17 in 1.2585015296936035\n",
      "Starting epoch 18\n",
      "Finished epoch 18 in 1.2153122425079346\n",
      "Starting epoch 19\n",
      "Finished epoch 19 in 1.230492353439331\n",
      "Done in 25.3052020072937\n"
     ]
    }
   ],
   "source": [
    "!python -m torch.distributed.launch --nproc_per_node=4 nemo_dis_train.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [callback API](https://nvidia.github.io/NeMo/tutorials/callbacks.html) makes setting up check points and evaluating the validation dataset easy. Interested readers please check the document for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
